/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/tsl/profiler/convert/post_process_single_host_xplane.h"

#include <cstdint>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "xla/tsl/platform/types.h"
#include "xla/tsl/profiler/utils/timestamp_utils.h"
#include "xla/tsl/profiler/utils/xplane_schema.h"
#include "xla/tsl/profiler/utils/xplane_utils.h"
#include "xla/tsl/profiler/utils/xplane_visitor.h"
#include "tsl/profiler/protobuf/xplane.pb.h"

namespace tsl {
namespace profiler {
namespace {

// Returns the set of line ids used by any line of any plane in `planes`.
//
// `planes` is only read, so it is taken by const reference; this also lets
// callers pass const vectors or temporaries. Every pointer in `planes` is
// dereferenced, so the vector must not contain null entries.
absl::flat_hash_set<int64_t> GetOccupiedLineIds(
    const std::vector<const XPlane*>& planes) {
  absl::flat_hash_set<int64_t> occupied_line_ids;
  for (const XPlane* plane : planes) {
    for (const XLine& line : plane->lines()) {
      occupied_line_ids.insert(line.id());
    }
  }
  return occupied_line_ids;
}

// Renumbers every line of `plane` whose id is already present in
// `occupied_line_ids`. A conflicting line receives the smallest id that is not
// yet occupied, searching upwards from `target_line_id_start`; the search
// position only moves forward across lines. Every id that ends up on a line of
// `plane` (kept or reassigned) is added to `occupied_line_ids`.
void ChangeOccupiedLineIds(XPlane* plane,
                           absl::flat_hash_set<int64_t>& occupied_line_ids,
                           int64_t target_line_id_start) {
  int64_t next_candidate_id = target_line_id_start;
  for (XLine& line : *plane->mutable_lines()) {
    // insert() reports whether the id was newly added; if so the line keeps
    // its id and the id is now marked as taken.
    const bool id_was_free = occupied_line_ids.insert(line.id()).second;
    if (id_was_free) {
      continue;
    }
    // Conflict: advance to the first candidate that is not occupied, claiming
    // it in the set as part of the same insert() call.
    while (!occupied_line_ids.insert(next_candidate_id).second) {
      ++next_candidate_id;
    }
    line.set_id(next_candidate_id);
    ++next_candidate_id;
  }
}

// Consolidates host-side planes of `space` into one plane named
// kHostThreadsPlaneName:
//   * planes found under the TPU runtime, CUPTI driver API, Python tracer,
//     Roctracer API and host-threads plane names are merged into a newly added
//     plane and then removed from `space`;
//   * the CUPTI NVTX activity plane, if present, is merged in as well after
//     its conflicting line ids have been renumbered;
//   * finally the lines of the merged plane are sorted by name.
void MergeHostPlanesAndSortLines(tensorflow::profiler::XSpace* space) {
  // The lookup happens before the destination plane is added, so the (still
  // empty) destination, which carries one of the searched names, is not part
  // of its own source list.
  std::vector<const XPlane*> source_planes = FindPlanesWithNames(
      *space,
      {kTpuRuntimePlaneName, kCuptiDriverApiPlaneName, kPythonTracerPlaneName,
       kRoctracerApiPlaneName, kHostThreadsPlaneName});

  // Snapshot the line ids in use while the source planes are still in `space`.
  absl::flat_hash_set<int64_t> used_line_ids =
      GetOccupiedLineIds(source_planes);

  auto* merged_host_plane = space->add_planes();
  merged_host_plane->set_name(std::string(kHostThreadsPlaneName));
  if (!source_planes.empty()) {
    MergePlanes(source_planes, merged_host_plane);
    RemovePlanes(space, source_planes);
  }

  // Fold the CUPTI NVTX plane into the host plane.
  static constexpr int64_t kNvtxLineIdStart = int64_t{1} << 32;
  if (XPlane* nvtx_plane =
          FindMutablePlaneWithName(space, kCuptiActivityNvtxPlaneName)) {
    // Line ids of the NVTX plane that clash with ids already merged above are
    // moved to unused ids (searched from kNvtxLineIdStart upwards) before the
    // merge, so that they stay distinct from the lines merged earlier.
    ChangeOccupiedLineIds(nvtx_plane, used_line_ids, kNvtxLineIdStart);
    MergePlanes({nvtx_plane}, merged_host_plane);
    RemovePlanes(space, {nvtx_plane});
  }

  // Order the merged lines by their names.
  SortXLinesBy(merged_host_plane, XLinesComparatorByName());
}

}  // namespace

// Post-processes the XSpace collected on a single host, in place.
//
// `space` is modified by each of the steps below. `start_time_ns` is the
// profiling start time used to normalize timestamps; together with
// `stop_time_ns` it is also recorded as the session timestamps of `space`.
// The steps run in a fixed order; see the NOTE on step 2.
void PostProcessSingleHostXSpace(tensorflow::profiler::XSpace* space,
                                 uint64_t start_time_ns,
                                 uint64_t stop_time_ns) {
  VLOG(3) << "Post processing local profiler XSpace.";
  // Post-process the collected XSpace without holding the profiler lock.
  // 1. Merge all host planes and sort their lines by name.
  MergeHostPlanesAndSortLines(space);
  // 2. Normalize all timestamps by shifting the timeline to the profiling
  // start time.
  // NOTE: this has to be done before sorting the XSpace due to timestamp
  // overflow.
  NormalizeTimestamps(space, start_time_ns);
  // 3. Add information regarding the profiling start_time_ns and stop_time_ns
  // to taskEnv.
  SetSessionTimestamps(start_time_ns, stop_time_ns, *space);
  // 4. Sort each plane of the XSpace.
  SortXSpace(space);
}

}  // namespace profiler
}  // namespace tsl
